{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ],
      "metadata": {
        "cellView": "form",
        "id": "IVkpU8HZ1eyz"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Use RunInference for Generative AI\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_generative_ai.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_generative_ai.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ],
      "metadata": {
        "id": "kH8SORNim8on"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This notebook shows how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform for generative AI tasks. It uses a large language model (LLM) from the [Hugging Face Model Hub](https://huggingface.co/models).\n",
        "\n",
        "This notebook demonstrates the following steps:\n",
        "- Load and save a model from the Hugging Face Model Hub.\n",
        "- Use the PyTorch model handler for RunInference.\n",
        "\n",
        "For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation."
      ],
      "metadata": {
        "id": "7N2XzwoA0k4L"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Install the Apache Beam SDK and dependencies\n",
        "\n",
        "Use the following code to install the Apache Beam Python SDK, PyTorch, and Transformers."
      ],
      "metadata": {
        "id": "nhf_lOeEsO1C"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wS9a3Y0oZ_l5"
      },
      "outputs": [],
      "source": [
        "!pip install apache_beam[interactive,gcp]==2.48.0\n",
        "!pip install torch\n",
        "!pip install transformers"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Use the following code to import dependencies\n",
        "\n",
        "**Important**: If an error occurs, restart your runtime."
      ],
      "metadata": {
        "id": "I7vMsFGW16bZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import apache_beam as beam\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
        "import torch\n",
        "from transformers import AutoConfig\n",
        "from transformers import AutoModelForSeq2SeqLM\n",
        "from transformers import AutoTokenizer\n",
        "from transformers.tokenization_utils import PreTrainedTokenizer\n",
        "\n",
        "\n",
        "MAX_RESPONSE_TOKENS = 256\n",
        "\n",
        "model_name = \"google/flan-t5-small\"\n",
        "state_dict_path = \"saved_model\""
      ],
      "metadata": {
        "id": "uhbOYUzvbOSc"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download and save the model\n",
        "This notebook uses the [auto classes](https://huggingface.co/docs/transformers/model_doc/auto) from Hugging Face to instantly load the model in memory. Later, the model is saved to the path defined previously."
      ],
      "metadata": {
        "id": "yRls3LmxswrC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = AutoModelForSeq2SeqLM.from_pretrained(\n",
        "        model_name, torch_dtype=torch.bfloat16\n",
        "    )\n",
        "\n",
        "directory = os.path.dirname(state_dict_path)\n",
        "torch.save(model.state_dict(), state_dict_path)"
      ],
      "metadata": {
        "id": "PKhkiQFJe44n"
      },
      "execution_count": null,
      "outputs": []
    },
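    {
      "cell_type": "markdown",
      "source": [
        "As a quick check, the saved state dictionary can be loaded back into a model that is built from the configuration alone. This is roughly the pattern that the `model_class` and `model_params` arguments of the model handler, configured later in this notebook, rely on. The following cell is a minimal sketch and not part of the original pipeline. It assumes that the previous cell has written the file `saved_model`; the names `config`, `restored_model`, and `state_dict` are only illustrative."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoConfig, AutoModelForSeq2SeqLM\n",
        "\n",
        "# Build a randomly initialized model from the configuration only.\n",
        "config = AutoConfig.from_pretrained(\"google/flan-t5-small\")\n",
        "restored_model = AutoModelForSeq2SeqLM.from_config(config)\n",
        "\n",
        "# Assumption: the previous cell saved the weights to \"saved_model\".\n",
        "state_dict = torch.load(\"saved_model\", map_location=\"cpu\")\n",
        "restored_model.load_state_dict(state_dict)\n",
        "restored_model.eval()\n",
        "\n",
        "print(type(restored_model).__name__)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },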
    {
      "cell_type": "markdown",
      "source": [
        "## Define utility functions\n",
        "The input and output for the [`google/flan-t5-small`](https://huggingface.co/google/flan-t5-small) model are token tensors. These utility functions are used for the conversion of text to token tensors and then back to text.\n"
      ],
      "metadata": {
        "id": "7TSqb3l1s7F7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def to_tensors(input_text: str, tokenizer) -> torch.Tensor:\n",
        "    \"\"\"Encodes input text into token tensors.\n",
        "    Args:\n",
        "        input_text: Input text for the LLM model.\n",
        "        tokenizer: Tokenizer for the LLM model.\n",
        "    Returns: Tokenized input tokens.\n",
        "    \"\"\"\n",
        "    return tokenizer(input_text, return_tensors=\"pt\").input_ids[0]\n",
        "\n",
        "\n",
        "def from_tensors(result: PredictionResult, tokenizer) -> str:\n",
        "    \"\"\"Decodes output token tensors into text.\n",
        "    Args:\n",
        "        result: Prediction results from the RunInference transform.\n",
        "        tokenizer: Tokenizer for the LLM model.\n",
        "    Returns: The model's response as text.\n",
        "    \"\"\"\n",
        "    output_tokens = result.inference\n",
        "    return tokenizer.decode(output_tokens, skip_special_tokens=True)"
      ],
      "metadata": {
        "id": "OeTMbaLidnBe"
      },
      "execution_count": 5,
      "outputs": []
    },
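    {
      "cell_type": "markdown",
      "source": [
        "To see what these helpers operate on, the following sketch tokenizes a prompt and decodes it again without running the model. It is illustrative only: `demo_tokenizer`, `demo_text`, and `demo_ids` are hypothetical names that are not used elsewhere in this notebook, and the decoded text might differ slightly from the input because of tokenizer normalization."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "# Hypothetical names used only for this demonstration.\n",
        "demo_tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-small\")\n",
        "demo_text = \"translate English to Spanish: We are in New York City.\"\n",
        "\n",
        "# Same encoding step as to_tensors: keep the first (only) row of token IDs.\n",
        "demo_ids = demo_tokenizer(demo_text, return_tensors=\"pt\").input_ids[0]\n",
        "print(demo_ids.shape)\n",
        "\n",
        "# Same decoding step as from_tensors, applied to the input tokens.\n",
        "print(demo_tokenizer.decode(demo_ids, skip_special_tokens=True))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },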
    {
      "cell_type": "code",
      "source": [
        "# Load the tokenizer.\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# Create an instance of the PyTorch model handler.\n",
        "model_handler = PytorchModelHandlerTensor(\n",
        "            state_dict_path=state_dict_path,\n",
        "            model_class=AutoModelForSeq2SeqLM.from_config,\n",
        "            model_params={\"config\": AutoConfig.from_pretrained(model_name)},\n",
        "            inference_fn=make_tensor_model_fn(\"generate\"),\n",
        "            )"
      ],
      "metadata": {
        "id": "77mOPlQR5Mev"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Run the Pipeline\n"
      ],
      "metadata": {
        "id": "y7IU3cYKtA3e"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "example = [\"translate English to Spanish: We are in New York City.\"]\n",
        "\n",
        "pipeline = beam.Pipeline(options=PipelineOptions(save_main_session=True,pickle_library=\"cloudpickle\"))\n",
        "\n",
        "with pipeline as p:\n",
        "  _ = (\n",
        "          p\n",
        "          | \"Create Examples\" >> beam.Create(example)\n",
        "          | \"To tensors\" >> beam.Map(to_tensors, tokenizer)\n",
        "          | \"RunInference\"\n",
        "            >> RunInference(\n",
        "                model_handler,\n",
        "                inference_args={\"max_new_tokens\": MAX_RESPONSE_TOKENS},\n",
        "            )\n",
        "          | \"From tensors\" >> beam.Map(from_tensors, tokenizer)\n",
        "          | \"Print\" >> beam.Map(print)\n",
        "      )"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 106
        },
        "id": "JUBrJEDYcjRG",
        "outputId": "e7e7f0a2-dc0d-427d-bdf8-047739bb3160"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Estamos en Nueva York City.\n"
          ]
        }
      ]
    }
  ]
}
